{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import numpy\n",
    "import time\n",
    "from sklearn.model_selection import cross_validate\n",
    "from sklearn.model_selection import cross_val_score\n",
    "import numpy as np\n",
    "from sklearn.model_selection import LeaveOneOut\n",
    "import math\n",
    "import numpy.linalg as LA\n",
    "from sklearn.cluster import KMeans\n",
    "import random\n",
    "from sklearn import tree\n",
    "from sklearn.ensemble import GradientBoostingClassifier\n",
    "from sklearn import model_selection,metrics\n",
    "from sklearn.model_selection import GridSearchCV\n",
    "import matplotlib.pylab as plt\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.metrics import roc_curve, auc\n",
    "from sklearn.model_selection import StratifiedKFold\n",
    "from scipy import interp\n",
    "from sklearn.preprocessing import OneHotEncoder\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "random.sample\n",
    "import pandas as pd\n",
    "from sklearn.model_selection import KFold\n",
    "from tqdm import tqdm\n",
    "from sklearn.metrics import confusion_matrix\n",
    "from sklearn.metrics import roc_auc_score, auc\n",
    "from sklearn.metrics import precision_recall_fscore_support\n",
    "from sklearn.metrics import precision_recall_curve\n",
    "from sklearn.metrics import classification_report\n",
    "from collections import Counter\n",
    "\n",
    "import time\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.metrics import roc_auc_score, roc_curve, auc\n",
    "from sklearn.metrics import average_precision_score, precision_recall_curve\n",
    "from sklearn.model_selection import train_test_split,StratifiedKFold\n",
    "import matplotlib.pyplot as plt\n",
    "from tflearn.activations import relu\n",
    "from optparse import OptionParser\n",
    "from scipy import interp\n",
    "from tqdm import trange\n",
    "import tensorflow as tf\n",
    "%matplotlib notebook"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(directory, random_seed):\n",
    "    all_associations = pd.read_csv(directory + '/all_gpe_pairs.csv')\n",
    "    known_associations = all_associations.loc[all_associations['label'] == 1]\n",
    "    print(len(known_associations))\n",
    "    peco_ids = list(set(known_associations['peco_idx']))\n",
    "    unknown_associations = all_associations.loc[all_associations['label'] == 0]\n",
    "    sample_df = known_associations\n",
    "    for peco_id in peco_ids:\n",
    "        random_negative = unknown_associations.loc[all_associations['peco_idx'] == peco_id].sample(n=known_associations.loc[all_associations['peco_idx'] == peco_id].shape[0], random_state=random_seed, axis=0, replace=True)\n",
    "        print(len(random_negative))\n",
    "        sample_df = pd.concat([sample_df,random_negative], axis=0)\n",
    "\n",
    "    sample_df.reset_index(drop=True, inplace=True)\n",
    "\n",
    "    return sample_df\n",
    "\n",
    "def obtain_data(directory, isbalance):\n",
    "\n",
    "    if isbalance:\n",
    "        dtp = sample(directory, random_seed = 1234)\n",
    "    else:\n",
    "        dtp = pd.read_csv(directory + '/all_gene_peco_pairs.csv')\n",
    "\n",
    "    gene_ids = list(set(dtp['gene_idx']))\n",
    "    peco_ids = list(set(dtp['peco_idx']))\n",
    "    random.shuffle(gene_ids)\n",
    "    random.shuffle(peco_ids)\n",
    "    print('# gene = {} | peco = {}'.format(len(gene_ids), len(peco_ids)))\n",
    "\n",
    "    gene_test_num = int(len(gene_ids) / 5)\n",
    "    peco_test_num = int(len(peco_ids) / 5)\n",
    "    print('# Test: gene = {} | peco = {}'.format(gene_test_num, peco_test_num))    \n",
    "    \n",
    "    return dtp, gene_ids, peco_ids, gene_test_num, peco_test_num\n",
    "\n",
    "def generate_task_Tp_train_test_idx(samples):\n",
    "    \"\"\"Split the rows of ``samples`` into 5 shuffled folds (random pair split).\n",
    "\n",
    "    Args:\n",
    "        samples: Pair table with ``gene_idx`` and ``peco_idx`` columns.\n",
    "\n",
    "    Returns:\n",
    "        Four lists with one entry per fold: positional train indices,\n",
    "        positional test indices, and the matching ``[gene_idx, peco_idx]``\n",
    "        arrays of the train rows and of the test rows.\n",
    "    \"\"\"\n",
    "    kf = KFold(n_splits = 5, shuffle = True, random_state = 1234)\n",
    "    id_columns = ['gene_idx', 'peco_idx']\n",
    "\n",
    "    train_index_all, test_index_all = [], []\n",
    "    train_id_all, test_id_all = [], []\n",
    "    # train_idx / test_idx are positional row indices into ``samples``\n",
    "    for fold, (train_idx, test_idx) in enumerate(tqdm(kf.split(samples.iloc[:, 3:]))):\n",
    "        print('-------Fold ', fold)\n",
    "        train_index_all.append(train_idx)\n",
    "        test_index_all.append(test_idx)\n",
    "\n",
    "        # Read the ids from the argument, not from a notebook-level ``dtp``.\n",
    "        train_id_all.append(np.array(samples.iloc[train_idx][id_columns]))\n",
    "        test_id_all.append(np.array(samples.iloc[test_idx][id_columns]))\n",
    "\n",
    "        print('# Pairs: Train = {} | Test = {}'.format(len(train_idx), len(test_idx)))\n",
    "    return train_index_all, test_index_all, train_id_all, test_id_all\n",
    "\n",
    "def generate_task_Tg_Tpe_train_test_idx(item, ids, dtp):\n",
    "    \n",
    "    test_num = int(len(ids) / 5)\n",
    "    \n",
    "    train_index_all, test_index_all = [], []\n",
    "    train_id_all, test_id_all = [], []\n",
    "    \n",
    "    for fold in range(5):\n",
    "        print('-------Fold ', fold)\n",
    "        if fold != 4:\n",
    "            test_ids = ids[fold * test_num : (fold + 1) * test_num]\n",
    "        else:\n",
    "            test_ids = ids[fold * test_num :]\n",
    "\n",
    "        train_ids = list(set(ids) ^ set(test_ids))\n",
    "        print('# {}: Train = {} | Test = {}'.format(item, len(train_ids), len(test_ids)))\n",
    "\n",
    "        test_idx = dtp[dtp[item].isin(test_ids)].index.tolist()\n",
    "        train_idx = dtp[dtp[item].isin(train_ids)].index.tolist()\n",
    "        random.shuffle(test_idx)\n",
    "        random.shuffle(train_idx)\n",
    "        print('# Pairs: Train = {} | Test = {}'.format(len(train_idx), len(test_idx)))\n",
    "        assert len(train_idx) + len(test_idx) == len(dtp)\n",
    "\n",
    "        train_index_all.append(train_idx) \n",
    "        test_index_all.append(test_idx)\n",
    "        \n",
    "        train_id_all.append(np.array(dtp.iloc[train_idx][['gene_idx', 'peco_idx']]))\n",
    "        test_id_all.append(np.array(dtp.iloc[test_idx][['gene_idx', 'peco_idx']]))\n",
    "        \n",
    "    return train_index_all, test_index_all, train_id_all, test_id_all\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyper-parameters are declared through optparse but parsed from an empty\n",
    "# argument list (args = []), so inside the notebook every option keeps its\n",
    "# default; edit the defaults below to change a setting.\n",
    "parser = OptionParser()\n",
    "parser.add_option(\"-d\", \"--d\", default=1024, help=\"The embedding dimension d\")\n",
    "parser.add_option(\"-n\", \"--n\", default=1, help=\"global norm to be clipped\")\n",
    "parser.add_option(\"-k\", \"--k\", default=1, help=\"The dimension of project matrices k\")\n",
    "# NOTE(review): opts.t is not read anywhere in the visible notebook.\n",
    "parser.add_option(\"-t\", \"--t\", default=\"o\", help=\"Test scenario\")\n",
    "(opts, args) = parser.parse_args(args = [])\n",
    "\n",
    "def check_symmetric(a, tol=1e-8):\n",
    "    return np.allclose(a, a.T, atol=tol)\n",
    "\n",
    "def row_normalize(a_matrix, substract_self_loop):\n",
    "    if substract_self_loop == True:\n",
    "        np.fill_diagonal(a_matrix,0)\n",
    "    a_matrix = a_matrix.astype(float)\n",
    "    row_sums = a_matrix.sum(axis=1)+1e-12\n",
    "    new_matrix = a_matrix / row_sums[:, np.newaxis]\n",
    "    new_matrix[np.isnan(new_matrix) | np.isinf(new_matrix)] = 0.0\n",
    "    return new_matrix\n",
    "\n",
    "def weight_variable(shape):\n",
    "    \"\"\"Return a float32 ``tf.Variable`` of ``shape`` drawn from a truncated normal (stddev 0.1).\"\"\"\n",
    "    init_values = tf.compat.v1.truncated_normal(shape, stddev=0.1, dtype=tf.float32)\n",
    "    variable = tf.Variable(init_values, dtype=tf.float32)\n",
    "    return variable\n",
    "\n",
    "def bias_variable(shape):\n",
    "    \"\"\"Return a float32 ``tf.Variable`` of ``shape`` with every entry set to 0.1.\"\"\"\n",
    "    init_values = tf.constant(0.1, shape=shape, dtype=tf.float32)\n",
    "    variable = tf.Variable(init_values, dtype=tf.float32)\n",
    "    return variable\n",
    "\n",
    "def a_layer(x,units):\n",
    "    \"\"\"Apply ``relu(x @ W + b)`` with freshly created ``W`` and ``b``.\n",
    "\n",
    "    ``W`` has shape ``[x.shape[1], units]``; its L2 penalty (factor 1.0) is\n",
    "    appended to the graph collection ``'l2_reg'``. Every call creates a new\n",
    "    pair of variables.\n",
    "    \"\"\"\n",
    "    input_dim = x.get_shape().as_list()[1]\n",
    "    W = weight_variable([input_dim, units])\n",
    "    b = bias_variable([units])\n",
    "    penalty = tf.keras.regularizers.L2(1.0)(W)\n",
    "    tf.compat.v1.add_to_collection('l2_reg', penalty)\n",
    "    pre_activation = tf.matmul(x, W) + b\n",
    "    return relu(pre_activation)\n",
    "\n",
    "\n",
    "def bi_layer(x0,x1,sym,dim_pred):\n",
    "    \"\"\"Return the score matrix ``(x0 @ W0p) @ (x1 @ W1p).T``.\n",
    "\n",
    "    With ``sym == False`` two separate projection matrices are created; in\n",
    "    every other case a single matrix ``W0p`` projects both inputs. The L2\n",
    "    penalty of each created matrix is appended to the ``'l2_reg'`` collection.\n",
    "    \"\"\"\n",
    "    separate_projections = (sym == False)\n",
    "\n",
    "    W0p = weight_variable([x0.get_shape().as_list()[1],dim_pred])\n",
    "    if separate_projections:\n",
    "        W1p = weight_variable([x1.get_shape().as_list()[1],dim_pred])\n",
    "    else:\n",
    "        W1p = W0p\n",
    "\n",
    "    tf.compat.v1.add_to_collection('l2_reg', tf.keras.regularizers.L2(1.0)(W0p))\n",
    "    if separate_projections:\n",
    "        tf.compat.v1.add_to_collection('l2_reg', tf.keras.regularizers.L2(1.0)(W1p))\n",
    "\n",
    "    projected_left = tf.matmul(x0, W0p)\n",
    "    projected_right = tf.matmul(x1, W1p)\n",
    "    return tf.matmul(projected_left, projected_right, transpose_b=True)\n",
    "\n",
    "\n",
    "# load network\n",
    "# All inputs are read relative to the notebook: the two similarity matrices\n",
    "# and the association matrix are whitespace-separated text files.\n",
    "network_path = r'../../data/'\n",
    "save_path = '../result/'\n",
    "# NOTE(review): gene_name and save_path are not used in the visible cells.\n",
    "gene_name = pd.read_excel(network_path + 'gene_name.xlsx')\n",
    "peco_similarity_matrix = np.loadtxt(network_path + 'PSSM.txt')\n",
    "gene_similarity_matrix = np.loadtxt(network_path + 'GSSM.txt')\n",
    "gene_peco_associations = np.loadtxt(network_path + 'GPmat.txt',dtype = float)\n",
    "\n",
    "# GPA is the transposed association matrix; its shape defines\n",
    "# (num_peco, num_gene) below.\n",
    "GPA = gene_peco_associations.T\n",
    "# row_normalize(..., True) zeroes the diagonal of its argument in place, so\n",
    "# after these two calls peco_similarity_matrix and gene_similarity_matrix\n",
    "# themselves no longer contain self-similarities.\n",
    "PS = row_normalize(peco_similarity_matrix, True)\n",
    "GS = row_normalize(gene_similarity_matrix, True)\n",
    "\n",
    "[num_peco, num_gene] = GPA.shape\n",
    "# Embedding, message-passing and input sizes all come from option d;\n",
    "# the projection size of the bilinear decoder comes from option k.\n",
    "dim_peco = int(opts.d)\n",
    "dim_gene = int(opts.d)\n",
    "dim_pred = int(opts.k)\n",
    "dim_pass = int(opts.d)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Model(object):\n",
    "    \"\"\"Graph definition of the peco/gene network reconstruction model.\n",
    "\n",
    "    Building an instance adds placeholders, variables and the loss to the\n",
    "    current default graph. Sizes are taken from the notebook-level names\n",
    "    ``num_peco``, ``num_gene``, ``dim_peco``, ``dim_gene``, ``dim_pass`` and\n",
    "    ``dim_pred``.\n",
    "    \"\"\"\n",
    "    def __init__(self):\n",
    "        self._build_model()\n",
    "\n",
    "    def _build_model(self):\n",
    "        \"\"\"Create placeholders, embeddings, one propagation step and the loss.\"\"\"\n",
    "        # Inputs: raw and row-normalised versions of every block of the\n",
    "        # joint (peco + gene) adjacency matrix, plus a mask over that matrix.\n",
    "        self.peco_peco = tf.compat.v1.placeholder(tf.float32, [num_peco, num_peco])\n",
    "        self.peco_peco_normalize = tf.compat.v1.placeholder(tf.float32, [num_peco, num_peco])\n",
    "        self.gene_gene = tf.compat.v1.placeholder(tf.float32, [num_gene, num_gene])\n",
    "        self.gene_gene_normalize = tf.compat.v1.placeholder(tf.float32, [num_gene, num_gene])\n",
    "        self.gene_peco = tf.compat.v1.placeholder(tf.float32, [num_gene, num_peco])\n",
    "        self.gene_peco_normalize = tf.compat.v1.placeholder(tf.float32, [num_gene, num_peco])\n",
    "        self.peco_gene = tf.compat.v1.placeholder(tf.float32, [num_peco, num_gene])\n",
    "        self.peco_gene_normalize = tf.compat.v1.placeholder(tf.float32, [num_peco, num_gene])\n",
    "        self.peco_gene_mask = tf.compat.v1.placeholder(tf.float32, [num_peco + num_gene, num_peco + num_gene])\n",
    "\n",
    "        # Reconstruction target: block matrix [[peco-peco, peco-gene],\n",
    "        #                                      [gene-peco, gene-gene]].\n",
    "        self.A = tf.concat([tf.concat([self.peco_peco, self.peco_gene], 1),\n",
    "                            tf.concat([self.gene_peco, self.gene_gene], 1)], 0)\n",
    "\n",
    "        # One embedding row per peco / gene; both are L2-regularised.\n",
    "        self.peco_embedding = weight_variable([num_peco, dim_peco])\n",
    "        self.gene_embedding = weight_variable([num_gene, dim_gene])\n",
    "        tf.compat.v1.add_to_collection('l2_reg', tf.keras.regularizers.L2(1.0)(self.peco_embedding))\n",
    "        tf.compat.v1.add_to_collection('l2_reg', tf.keras.regularizers.L2(1.0)(self.gene_embedding))\n",
    "\n",
    "        # W0 / b0 are shared by the peco and the gene update below. The gene\n",
    "        # update concatenates dim_pass + dim_gene columns, so this shape only\n",
    "        # fits while dim_gene == dim_peco (both are set from the same option).\n",
    "        W0 = weight_variable([dim_pass + dim_peco, dim_peco])\n",
    "        b0 = bias_variable([dim_peco])\n",
    "        tf.compat.v1.add_to_collection('l2_reg', tf.keras.regularizers.L2(1.0)(W0))\n",
    "\n",
    "        # Peco update: aggregate transformed gene and peco embeddings through\n",
    "        # the normalised adjacency blocks, concatenate the peco's own\n",
    "        # embedding, project with W0/b0, apply ReLU and L2-normalise each row.\n",
    "        # Every a_layer(...) call creates its own weights, so the four calls\n",
    "        # in this method are four independent layers.\n",
    "        peco_vector0 = tf.compat.v1.nn.l2_normalize(relu(tf.matmul(\n",
    "            tf.concat([tf.matmul(self.peco_gene_normalize,\n",
    "                                 a_layer(self.gene_embedding, dim_pass)) +\n",
    "                       tf.matmul(self.peco_peco_normalize,\n",
    "                                 a_layer(self.peco_embedding, dim_pass)),\n",
    "                       self.peco_embedding], axis=1), W0) + b0), dim=1)\n",
    "\n",
    "        # Gene update: same construction from the gene's point of view.\n",
    "        gene_vector0 = tf.compat.v1.nn.l2_normalize(relu(tf.matmul(\n",
    "            tf.concat([tf.matmul(self.gene_gene_normalize,\n",
    "                                 a_layer(self.gene_embedding, dim_pass)) +\n",
    "                       tf.matmul(self.gene_peco_normalize,\n",
    "                                 a_layer(self.peco_embedding, dim_pass)),\n",
    "                       self.gene_embedding], axis=1), W0) + b0), dim=1)\n",
    "\n",
    "        self.peco_representation = peco_vector0\n",
    "        self.gene_representation = gene_vector0\n",
    "        # Stack pecos on top of genes (same row order as self.A) and decode\n",
    "        # the whole matrix with one shared projection (sym=True).\n",
    "        self.features_matrix = tf.concat([self.peco_representation, self.gene_representation], 0)\n",
    "        self.A_reconstruct = bi_layer(self.features_matrix, self.features_matrix, sym=True, dim_pred=dim_pred)\n",
    "\n",
    "        # Loss = masked sum of squared reconstruction errors + every penalty\n",
    "        # that was registered in the 'l2_reg' collection.\n",
    "        tmp = tf.multiply(self.peco_gene_mask, (self.A_reconstruct - self.A))\n",
    "        self.A_reconstruct_loss = tf.reduce_sum(tf.multiply(tmp, tmp))\n",
    "        self.l2_loss = tf.add_n(tf.compat.v1.get_collection(\"l2_reg\"))\n",
    "        self.loss = self.A_reconstruct_loss + self.l2_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Build the model, the optimiser and the prediction tensor once, in the\n",
    "# default graph; train_and_evaluate() later opens sessions on this graph.\n",
    "graph = tf.compat.v1.get_default_graph()\n",
    "with graph.as_default():\n",
    "    model = Model()\n",
    "    # The learning rate is fed at run time rather than fixed in the graph.\n",
    "    learning_rate = tf.compat.v1.placeholder(tf.float32, [])\n",
    "    total_loss = model.loss\n",
    "\n",
    "    # Adam step with the gradients clipped to a global norm of opts.n.\n",
    "    optimize = tf.compat.v1.train.AdamOptimizer(learning_rate)\n",
    "    gradients, variables = zip(*optimize.compute_gradients(total_loss))\n",
    "    gradients, _ = tf.clip_by_global_norm(gradients, int(opts.n))\n",
    "    optimizer = optimize.apply_gradients(zip(gradients, variables))\n",
    "\n",
    "    # The reconstructed matrix holds the peco-gene scores twice: in the\n",
    "    # top-right block (peco rows, gene columns) and in the bottom-left block\n",
    "    # (gene rows, peco columns). The prediction is the average of the two,\n",
    "    # laid out as (num_peco, num_gene).\n",
    "    DR_pred = model.A_reconstruct[:num_peco, num_peco:]\n",
    "    RD_pred = model.A_reconstruct[num_peco:, :num_peco]\n",
    "    eval_pred = (DR_pred + tf.transpose(RD_pred, perm=[1,0])) / 2.0\n",
    "\n",
    "\n",
    "def divide_known_unknown_associations(A, exception=None, special=None):\n",
    "    known = []\n",
    "    unknown = []\n",
    "    if special != None:\n",
    "        for j in range(A.shape[1]):\n",
    "            if A[special][j] == 1:\n",
    "                known.append([special,j,1])\n",
    "            else:\n",
    "                unknown.append([special,j,0])\n",
    "    else:\n",
    "        for i in range(A.shape[0]):\n",
    "            if i == exception: pass\n",
    "            for j in range(A.shape[1]):\n",
    "                if A[i][j] == 1:\n",
    "                    known.append([i,j,1])\n",
    "                else:\n",
    "                    unknown.append([i,j,0])\n",
    "\n",
    "    return np.array(known), np.array(unknown)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "def performances(y_true, y_pred, y_prob):\n",
    "    \"\"\"Print and return binary classification metrics.\n",
    "\n",
    "    Args:\n",
    "        y_true: True 0/1 labels.\n",
    "        y_pred: Predicted 0/1 labels (used for the confusion matrix).\n",
    "        y_prob: Predicted scores (used for ROC-AUC and the PR curve).\n",
    "\n",
    "    Returns:\n",
    "        ``((y_true, y_pred, y_prob),\n",
    "        (accuracy, precision, recall, f1, roc_auc, aupr))``. Precision, recall\n",
    "        and f1 fall back to 0 when their denominator is zero.\n",
    "    \"\"\"\n",
    "    def _ratio(numerator, denominator):\n",
    "        # 0 instead of a ZeroDivisionError for an empty denominator\n",
    "        return 0 if denominator == 0 else numerator / denominator\n",
    "\n",
    "    counts = confusion_matrix(y_true, y_pred, labels = [0, 1]).ravel().tolist()\n",
    "    tn, fp, fn, tp = counts\n",
    "\n",
    "    accuracy = (tp+tn)/(tn+fp+fn+tp)\n",
    "    recall = _ratio(tp, tp+fn)\n",
    "    precision = _ratio(tp, tp+fp)\n",
    "    f1 = _ratio(2*precision*recall, precision+recall)\n",
    "\n",
    "    roc_auc = roc_auc_score(y_true, y_prob)\n",
    "    # area under the precision-recall curve\n",
    "    prec, reca, _ = precision_recall_curve(y_true, y_prob)\n",
    "    aupr = auc(reca, prec)\n",
    "\n",
    "    pred_counts = Counter(y_pred)\n",
    "    true_counts = Counter(y_true)\n",
    "    print('tn = {}, fp = {}, fn = {}, tp = {}'.format(tn, fp, fn, tp))\n",
    "    print('y_pred: 0 = {} | 1 = {}'.format(pred_counts[0], pred_counts[1]))\n",
    "    print('y_true: 0 = {} | 1 = {}'.format(true_counts[0], true_counts[1]))\n",
    "    print('acc={:.4f}|precision={:.4f}|recall={:.4f}|f1={:.4f}|auc={:.4f}|aupr={:.4f}'.format(accuracy, precision, recall, f1, roc_auc, aupr))\n",
    "\n",
    "    inputs = (y_true, y_pred, y_prob)\n",
    "    metrics_summary = (accuracy, precision, recall, f1, roc_auc, aupr)\n",
    "    return inputs, metrics_summary\n",
    "\n",
    "def transfer(y_pred):\n",
    "    return [[0,1][x>0.5] for x in y_pred.reshape(-1)]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_and_evaluate(GPAtrain, GPAtest, graph, verbose=True, num_steps = 2000):\n",
    "    \"\"\"Train the notebook-level model on ``GPAtrain`` and score ``GPAtest``.\n",
    "\n",
    "    Uses names created by earlier cells: ``model``, ``optimizer``,\n",
    "    ``total_loss``, ``eval_pred``, ``learning_rate``, ``num_peco``,\n",
    "    ``num_gene``, ``peco_similarity_matrix``, ``gene_similarity_matrix``,\n",
    "    ``PS``, ``GS`` and ``row_normalize``. Variables are re-initialised on every\n",
    "    call, so each call trains from scratch.\n",
    "\n",
    "    Args:\n",
    "        GPAtrain: Rows whose first three entries are used as\n",
    "            ``(row, column, label)`` of the ``(num_peco, num_gene)``\n",
    "            association matrix.\n",
    "            NOTE(review): the calling cells pass ``np.array(dtp.iloc[idx])``,\n",
    "            which presumes the first two columns of ``dtp`` are the peco\n",
    "            index followed by the gene index -- confirm against the CSV.\n",
    "        GPAtest: Rows in the same layout; they are only scored.\n",
    "        graph: Graph on which the session is opened.\n",
    "        verbose: When true, print ``num_peco`` and ``num_gene``. It no longer\n",
    "            influences which predictions are kept (previously ``False`` left\n",
    "            the result undefined and the function crashed).\n",
    "        num_steps: Number of optimisation steps.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(scores, labels)``: the predicted score and the label of every\n",
    "        test row. Scores come from the checked step (every 20th) with the\n",
    "        lowest training loss.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If test rows were given but no prediction could be kept\n",
    "            (``num_steps`` < 1, or the loss was never comparable, e.g. NaN).\n",
    "    \"\"\"\n",
    "    if verbose:\n",
    "        print(num_peco)\n",
    "        print(num_gene)\n",
    "\n",
    "    # Known training labels and the mask telling the loss which peco-gene\n",
    "    # entries are observed.\n",
    "    mask = np.zeros((num_peco,num_gene))\n",
    "    peco_gene = np.zeros((num_peco,num_gene))\n",
    "    for ele in GPAtrain:\n",
    "        peco_gene[ele[0],ele[1]] = ele[2]\n",
    "        mask[ele[0],ele[1]] = 1\n",
    "    gene_peco = peco_gene.T\n",
    "    peco_gene_normalize = row_normalize(peco_gene,False)\n",
    "    gene_peco_normalize = row_normalize(gene_peco,False)\n",
    "    # The similarity blocks are always part of the loss (mask of ones).\n",
    "    mask = np.concatenate([np.concatenate([np.ones((num_peco, num_peco)), mask], axis=1),\n",
    "                           np.concatenate([mask.T, np.ones((num_gene, num_gene))], axis=1)], axis=0)\n",
    "\n",
    "    lr = 0.005\n",
    "    # The inputs never change between steps, so build the feed dict once.\n",
    "    feed = {model.peco_gene:peco_gene, model.peco_gene_normalize:peco_gene_normalize,\n",
    "            model.gene_peco:gene_peco, model.gene_peco_normalize:gene_peco_normalize,\n",
    "            model.peco_peco:peco_similarity_matrix, model.gene_gene:gene_similarity_matrix,\n",
    "            model.peco_peco_normalize:PS, model.gene_gene_normalize:GS,\n",
    "            model.peco_gene_mask:mask, learning_rate: lr}\n",
    "    fetches = [optimizer, total_loss, eval_pred]\n",
    "\n",
    "    min_loss = float('inf')\n",
    "    best_results = None\n",
    "    with tf.compat.v1.Session(graph=graph) as sess:\n",
    "        tf.compat.v1.global_variables_initializer().run()\n",
    "        for i in trange(num_steps):\n",
    "            _, tloss, results = sess.run(fetches, feed_dict=feed)\n",
    "            # every 20 steps of gradient descent, keep the predictions if the\n",
    "            # loss improved; other choices of this number are possible\n",
    "            if i % 20 == 0 and tloss <= min_loss:\n",
    "                min_loss = tloss\n",
    "                best_results = results\n",
    "\n",
    "    if best_results is None and len(GPAtest) > 0:\n",
    "        raise ValueError('no predictions were recorded: num_steps must be >= 1 and the loss must be finite')\n",
    "\n",
    "    scores = [best_results[ele[0], ele[1]] for ele in GPAtest]\n",
    "    labels = [ele[2] for ele in GPAtest]\n",
    "\n",
    "    return scores, labels\n",
    "\n",
    "def transfer(y_pred):\n",
    "    return [[0,1][x>0.5] for x in y_pred]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tp, balance: build the 5-fold pair split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "5it [00:00, 119.12it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# gene = 11177 | peco = 24\n",
      "# Test: gene = 2235 | peco = 4\n",
      "-------Fold  0\n",
      "# Pairs: Train = 37692 | Test = 9424\n",
      "-------Fold  1\n",
      "# Pairs: Train = 37693 | Test = 9423\n",
      "-------Fold  2\n",
      "# Pairs: Train = 37693 | Test = 9423\n",
      "-------Fold  3\n",
      "# Pairs: Train = 37693 | Test = 9423\n",
      "-------Fold  4\n",
      "# Pairs: Train = 37693 | Test = 9423\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Balanced pair table plus the random 5-fold pair split (task Tp).\n",
    "# obtain_data(isbalance=True) already calls sample(), so a separate sample()\n",
    "# call whose result is thrown away is not needed here.\n",
    "directory = r'../../data'\n",
    "dtp, gene_ids, peco_ids, gene_test_num, peco_test_num = obtain_data(directory, isbalance = True)\n",
    "train_index_all, test_index_all, train_id_all, test_id_all = generate_task_Tp_train_test_idx(dtp)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tp, balance: train and evaluate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "epochs = 2000\n",
    "\n",
    "# One (raw outputs, metrics) tuple per fold, so the result of every fold is\n",
    "# still available after the loop instead of only the last one.\n",
    "tp_fold_performances = []\n",
    "for train_idx, test_idx in zip(train_index_all, test_index_all):\n",
    "    GPAtrain = np.array(dtp.iloc[train_idx])\n",
    "    GPAtest = np.array(dtp.iloc[test_idx])\n",
    "\n",
    "    scores, labels = train_and_evaluate(GPAtrain=GPAtrain, GPAtest=GPAtest,\n",
    "                                        graph=graph, num_steps=epochs)\n",
    "    y_test_prob = scores\n",
    "    y_test_pred = transfer(y_test_prob)\n",
    "    performances_test = performances(labels, y_test_pred, y_test_prob)\n",
    "    tp_fold_performances.append(performances_test)\n",
    "\n",
    "# Cross-validation summary: mean of each metric over the folds.\n",
    "tp_mean_metrics = np.mean([metrics for _, metrics in tp_fold_performances], axis=0)\n",
    "print('mean: acc={:.4f}|precision={:.4f}|recall={:.4f}|f1={:.4f}|auc={:.4f}|aupr={:.4f}'.format(*tp_mean_metrics))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "# Tg, balance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "# gene = 11177 | peco = 24\n",
      "# Test: gene = 2235 | peco = 4\n",
      "-------Fold  0\n",
      "# gene_idx: Train = 8942 | Test = 2235\n",
      "# Pairs: Train = 37625 | Test = 9491\n",
      "-------Fold  1\n",
      "# gene_idx: Train = 8942 | Test = 2235\n",
      "# Pairs: Train = 37721 | Test = 9395\n",
      "-------Fold  2\n",
      "# gene_idx: Train = 8942 | Test = 2235\n",
      "# Pairs: Train = 37708 | Test = 9408\n",
      "-------Fold  3\n",
      "# gene_idx: Train = 8942 | Test = 2235\n",
      "# Pairs: Train = 37619 | Test = 9497\n",
      "-------Fold  4\n",
      "# gene_idx: Train = 8940 | Test = 2237\n",
      "# Pairs: Train = 37791 | Test = 9325\n",
      "32\n",
      "12187\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 2000/2000 [8:00:53<00:00, 14.43s/it]\n",
      "<ipython-input-26-bff51fc9c44e>:47: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tn = 4318, fp = 427, fn = 1639, tp = 3107\n",
      "y_pred: 0 = 5957 | 1 = 3534\n",
      "y_true: 0 = 4745 | 1 = 4746\n",
      "acc=0.7823|precision=0.8792|recall=0.6547|f1=0.7505|auc=0.8773|aupr=0.8901\n",
      "32\n",
      "12187\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 2000/2000 [7:58:31<00:00, 14.36s/it]\n",
      "<ipython-input-26-bff51fc9c44e>:47: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tn = 4184, fp = 470, fn = 1474, tp = 3267\n",
      "y_pred: 0 = 5658 | 1 = 3737\n",
      "y_true: 0 = 4654 | 1 = 4741\n",
      "acc=0.7931|precision=0.8742|recall=0.6891|f1=0.7707|auc=0.8969|aupr=0.9053\n",
      "32\n",
      "12187\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 2000/2000 [7:42:38<00:00, 13.88s/it]\n",
      "<ipython-input-26-bff51fc9c44e>:47: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tn = 4093, fp = 595, fn = 1518, tp = 3202\n",
      "y_pred: 0 = 5611 | 1 = 3797\n",
      "y_true: 0 = 4688 | 1 = 4720\n",
      "acc=0.7754|precision=0.8433|recall=0.6784|f1=0.7519|auc=0.8615|aupr=0.8728\n",
      "32\n",
      "12187\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 2000/2000 [6:57:00<00:00, 12.51s/it]\n",
      "<ipython-input-26-bff51fc9c44e>:47: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tn = 4188, fp = 596, fn = 1536, tp = 3177\n",
      "y_pred: 0 = 5724 | 1 = 3773\n",
      "y_true: 0 = 4784 | 1 = 4713\n",
      "acc=0.7755|precision=0.8420|recall=0.6741|f1=0.7488|auc=0.8539|aupr=0.8745\n",
      "32\n",
      "12187\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 2000/2000 [6:44:08<00:00, 12.12s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tn = 4193, fp = 494, fn = 1596, tp = 3042\n",
      "y_pred: 0 = 5789 | 1 = 3536\n",
      "y_true: 0 = 4687 | 1 = 4638\n",
      "acc=0.7759|precision=0.8603|recall=0.6559|f1=0.7443|auc=0.8618|aupr=0.8721\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-26-bff51fc9c44e>:47: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred]\n"
     ]
    }
   ],
   "source": [
    "# Relative data folder (the same location as network_path) instead of a\n",
    "# machine-specific absolute path, so the cell runs on any checkout.\n",
    "directory = r'../../data'\n",
    "dtp, gene_ids, peco_ids, gene_test_num, peco_test_num = obtain_data(directory, isbalance = True)\n",
    "# Task Tg: whole genes are held out in each fold.\n",
    "item = 'gene_idx'\n",
    "ids = gene_ids\n",
    "train_index_all, test_index_all, train_id_all, test_id_all = generate_task_Tg_Tpe_train_test_idx(item, ids, dtp)\n",
    "\n",
    "epochs = 2000\n",
    "\n",
    "# One (raw outputs, metrics) tuple per fold, so the result of every fold is\n",
    "# still available after the loop instead of only the last one.\n",
    "tg_fold_performances = []\n",
    "for train_idx, test_idx in zip(train_index_all, test_index_all):\n",
    "    GPAtrain = np.array(dtp.iloc[train_idx])\n",
    "    GPAtest = np.array(dtp.iloc[test_idx])\n",
    "\n",
    "    scores, labels = train_and_evaluate(GPAtrain=GPAtrain, GPAtest=GPAtest,\n",
    "                                        graph=graph, num_steps=epochs)\n",
    "\n",
    "    y_test_prob = scores\n",
    "    y_test_pred = transfer(y_test_prob)\n",
    "    performances_test = performances(labels, y_test_pred, y_test_prob)\n",
    "    tg_fold_performances.append(performances_test)\n",
    "\n",
    "# Cross-validation summary: mean of each metric over the folds.\n",
    "tg_mean_metrics = np.mean([metrics for _, metrics in tg_fold_performances], axis=0)\n",
    "print('mean: acc={:.4f}|precision={:.4f}|recall={:.4f}|f1={:.4f}|auc={:.4f}|aupr={:.4f}'.format(*tg_mean_metrics))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Tpe, balance"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "23558\n",
      "1528\n",
      "10396\n",
      "86\n",
      "121\n",
      "1179\n",
      "126\n",
      "5\n",
      "14\n",
      "45\n",
      "4992\n",
      "7\n",
      "67\n",
      "312\n",
      "853\n",
      "148\n",
      "29\n",
      "831\n",
      "45\n",
      "476\n",
      "2239\n",
      "1\n",
      "51\n",
      "1\n",
      "6\n",
      "# gene = 11177 | peco = 24\n",
      "# Test: gene = 2235 | peco = 4\n",
      "-------Fold  0\n",
      "# peco_idx: Train = 20 | Test = 4\n",
      "# Pairs: Train = 40644 | Test = 6472\n",
      "-------Fold  1\n",
      "# peco_idx: Train = 20 | Test = 4\n",
      "# Pairs: Train = 46394 | Test = 722\n",
      "-------Fold  2\n",
      "# peco_idx: Train = 20 | Test = 4\n",
      "# Pairs: Train = 44006 | Test = 3110\n",
      "-------Fold  3\n",
      "# peco_idx: Train = 20 | Test = 4\n",
      "# Pairs: Train = 43714 | Test = 3402\n",
      "-------Fold  4\n",
      "# peco_idx: Train = 16 | Test = 8\n",
      "# Pairs: Train = 13706 | Test = 33410\n",
      "32\n",
      "12187\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 30/30 [06:00<00:00, 12.01s/it]\n",
      "<ipython-input-14-bff51fc9c44e>:47: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tn = 45, fp = 3191, fn = 49, tp = 3187\n",
      "y_pred: 0 = 94 | 1 = 6378\n",
      "y_true: 0 = 3236 | 1 = 3236\n",
      "acc=0.4994|precision=0.4997|recall=0.9849|f1=0.6630|auc=0.5192|aupr=0.5418\n",
      "32\n",
      "12187\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 30/30 [05:43<00:00, 11.44s/it]\n",
      "<ipython-input-14-bff51fc9c44e>:47: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tn = 0, fp = 361, fn = 0, tp = 361\n",
      "y_pred: 0 = 0 | 1 = 722\n",
      "y_true: 0 = 361 | 1 = 361\n",
      "acc=0.5000|precision=0.5000|recall=1.0000|f1=0.6667|auc=0.6736|aupr=0.6157\n",
      "32\n",
      "12187\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 30/30 [05:43<00:00, 11.45s/it]\n",
      "<ipython-input-14-bff51fc9c44e>:47: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tn = 0, fp = 1555, fn = 2, tp = 1553\n",
      "y_pred: 0 = 2 | 1 = 3108\n",
      "y_true: 0 = 1555 | 1 = 1555\n",
      "acc=0.4994|precision=0.4997|recall=0.9987|f1=0.6661|auc=0.6203|aupr=0.5835\n",
      "32\n",
      "12187\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 30/30 [05:53<00:00, 11.79s/it]\n",
      "<ipython-input-14-bff51fc9c44e>:47: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tn = 1698, fp = 3, fn = 1700, tp = 1\n",
      "y_pred: 0 = 3398 | 1 = 4\n",
      "y_true: 0 = 1701 | 1 = 1701\n",
      "acc=0.4994|precision=0.2500|recall=0.0006|f1=0.0012|auc=0.5890|aupr=0.5763\n",
      "32\n",
      "12187\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████| 30/30 [05:48<00:00, 11.60s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tn = 11179, fp = 5526, fn = 11096, tp = 5609\n",
      "y_pred: 0 = 22275 | 1 = 11135\n",
      "y_true: 0 = 16705 | 1 = 16705\n",
      "acc=0.5025|precision=0.5037|recall=0.3358|f1=0.4029|auc=0.4887|aupr=0.4902\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<ipython-input-14-bff51fc9c44e>:47: DeprecationWarning: In future, it will be an error for 'np.bool_' scalars to be interpreted as an index\n",
      "  return [[0,1][x>0.5] for x in y_pred]\n"
     ]
    }
   ],
   "source": [
    "# Relative data folder (the same location as network_path) instead of a\n",
    "# machine-specific absolute path, so the cell runs on any checkout.\n",
    "directory = r'../../data'\n",
    "dtp, gene_ids, peco_ids, gene_test_num, peco_test_num = obtain_data(directory, isbalance = True)\n",
    "# Task Tpe: whole pecos are held out in each fold.\n",
    "item = 'peco_idx'\n",
    "ids = peco_ids\n",
    "train_index_all, test_index_all, train_id_all, test_id_all = generate_task_Tg_Tpe_train_test_idx(item, ids, dtp)\n",
    "\n",
    "# NOTE(review): only 30 steps here versus 2000 in the Tp / Tg runs -- confirm\n",
    "# this is intended and not a leftover from a quick trial.\n",
    "epochs = 30\n",
    "\n",
    "# One (raw outputs, metrics) tuple per fold, so the result of every fold is\n",
    "# still available after the loop instead of only the last one.\n",
    "tpe_fold_performances = []\n",
    "for train_idx, test_idx in zip(train_index_all, test_index_all):\n",
    "    GPAtrain = np.array(dtp.iloc[train_idx])\n",
    "    GPAtest = np.array(dtp.iloc[test_idx])\n",
    "\n",
    "    scores, labels = train_and_evaluate(GPAtrain=GPAtrain, GPAtest=GPAtest,\n",
    "                                        graph=graph, num_steps=epochs)\n",
    "\n",
    "    y_test_prob = scores\n",
    "    y_test_pred = transfer(y_test_prob)\n",
    "    performances_test = performances(labels, y_test_pred, y_test_prob)\n",
    "    tpe_fold_performances.append(performances_test)\n",
    "\n",
    "# Cross-validation summary: mean of each metric over the folds.\n",
    "tpe_mean_metrics = np.mean([metrics for _, metrics in tpe_fold_performances], axis=0)\n",
    "print('mean: acc={:.4f}|precision={:.4f}|recall={:.4f}|f1={:.4f}|auc={:.4f}|aupr={:.4f}'.format(*tpe_mean_metrics))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
